// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <xnnpack/assembly.h>

.syntax unified

// void xnn_f32_igemm_minmax_ukernel_4x8__aarch32_neon_cortex_a55(
//     size_t mr,                            r0
//     size_t nc,                            r1
//     size_t kc,                            r2 -> r5
//     size_t ks,                            r3 -> sp + 64 -> r14
//     const float**restrict a,  sp + 104 -> (r5)
//     const void*restrict w,    sp + 108 -> r9
//     uint8_t*restrict c,       sp + 112 -> r11
//     size_t cm_stride,         sp + 116 -> (r6)
//     size_t cn_stride,         sp + 120 -> (r0)
//     size_t a_offset,          sp + 124 -> (r5)
//     const float* zero,        sp + 128 -> (r0)
//     minmax_params*params,     sp + 132 -> (r5)

// inner loop registers

// A0   r3  d0
// A1  r12  d1
// A2  r10  d2
// A3   r7  d3

// B    r9  d8,  d9, d10, d11
// B       d12, d13, d14, d15

// C0  r11 d16-d17  q8  d18-d19  q9
// C1   r4 d20-d21 q10  d22-d23 q11
// C2   r8 d24-d25 q12  d26-d27 q13
// C3   r6 d28-d29 q14  d30-d31 q15

// Clamp (r5) d4 d5 d6 d7

BEGIN_FUNCTION xnn_f32_igemm_minmax_ukernel_4x8__aarch32_neon_cortex_a55
        .arm
#ifndef __APPLE__
        .arch   armv7-a
        .fpu    neon
#endif
        # Push 104 bytes
        PUSH    {r3, r4, r5, r6, r7, r8, r9, r10, r11, r14}  // +40
        VPUSH   {d8-d15}                                     // +64 = 104

        LDR     r11, [sp, 112]          // c
        LDR     r6, [sp, 116]           // cm_stride
        LDR     r5, [sp, 104]           // a
        LDR     r9, [sp, 108]           // w
        MOV     r14, r3                 // p = ks

        # Clamp C pointers
        CMP     r0, 2                   // if mr >= 2
        ADD     r4, r11, r6             //   c1 = c0 + cm_stride
        MOVLO   r4, r11                 // c1
                                     // if mr > 2
        ADD     r8, r4, r6              //   c2 = c1 + cm_stride
        MOVLS   r8, r4                  // c2
        CMP     r0, 4                   // if mr >=4
        ADD     r6, r8, r6              //   c3 = c2 + cm_stride
        MOVLO   r6, r8                  // c3


        .p2align 3
0:
        # Load initial bias from w into accumulators
        VLDM    r9!, {d16-d19}          // Bias

        VMOV    q10, q8
        VMOV    q11, q9
        VMOV    q12, q8
        VMOV    q13, q9
        PLD     [r9,   0]               // Prefetch B
        PLD     [r9,  64]
        VMOV    q14, q8
        PLD     [r9, 128]
        PLD     [r9, 192]
        VMOV    q15, q9
        PLD     [r9, 256]
        PLD     [r9, 320]

1:
        # Load next 4 A pointers
        LDR     r3, [r5,  0]
        LDR     r12, [r5,  4]
        LDR     r10, [r5,  8]
        LDR     r7, [r5, 12]
        ADD     r5, r5, 16
        PLD     [r3,  0]                // Prefetch A
        STR     r5, [sp, 104]           // a
        PLD     [r3, 64]
        LDR     r0, [sp, 128]           // zero
        PLD     [r12,  0]
        LDR     r5, [sp, 124]           // a_offset
        PLD     [r12, 64]
        PLD     [r10,  0]
        PLD     [r10, 64]
        PLD     [r7,  0]
        PLD     [r7, 64]

        # Add a_offset
        CMP     r3,  r0                 // if a0 == zero
        ADD     r3,  r3, r5             // a0 += a_offset
        MOVEQ   r3,  r0                 //   a0 = zero, else += a0 + a_offset
        CMP     r12,  r0                // if a1 == zero
        ADD     r12, r12, r5            // a1 += a_offset
        MOVEQ   r12,  r0                //   a1 = zero, else += a1 + a_offset
        CMP     r10,  r0                // if a2 == zero
        ADD     r10, r10, r5            // a2 += a_offset
        MOVEQ   r10,  r0                //   a2 = zero, else += a2 + a_offset
        CMP     r7,  r0                 // if a3 == zero
        ADD     r7,  r7, r5             // a3 += a_offset
        MOVEQ   r7,  r0                 //   a3 = zero, else += a3 + a_offset

        SUBS    r5, r2, 16              // kc - 16
        BLO     5f                      // less than 4 channels?

        # Prologue
        VLD1.32 {d0},  [r3]!            // A0
        VLD1.32 {d1}, [r12]!            // A1
        VLD1.32 {d2}, [r10]!            // A2
        VLD1.32 {d3},  [r7]!            // A3
        SUBS    r5, r5, 16
        VLDM    r9, {d8-d11}            // B0
        VLDR    d15, [r9, 56]           // B1CK 0
        VLDR    d13, [r9, 40]           // B1

        BLO     3f                      // less than 4 channels?  skip main loop

        # Main loop - 4 floats of A (16 bytes)
        # 32 FMA + 8 LD64 A + 8 LDR B
        .p2align 3
2:
        # First group of 16 FMA, Second group loads
        # BLOCK 0
        VMLA.F32 q8, q4, d0[0]
        VLD1.32 {d4}, [r3]!             // A0
        VMLA.F32 q10, q4, d1[0]
        VLD1.32 {d5}, [r12]!            // A1
        VMLA.F32 q12, q4, d2[0]

        # BLOCK 1
        VMLA.F32 q14, q4, d3[0]
        VLDR    d12, [r9, 32]           // B1
        VMLA.F32 q9, q5, d0[0]
        VLDR    d9, [r9, 72]            // B0
        VMLA.F32 q11, q5, d1[0]

        # BLOCK 2
        VMLA.F32 q13, q5, d2[0]
        VLD1.32 {d6}, [r10]!            // A2
        VMLA.F32 q15, q5, d3[0]
        VLD1.32 {d7}, [r7]!             // A3
        VMLA.F32 q8, q6, d0[1]

        # BLOCK 3
        VMLA.F32 q10, q6, d1[1]
        VLDR    d14, [r9, 48]           // B1
        VMLA.F32 q12, q6, d2[1]
        VLDR    d11, [r9, 88]           // B0
        VMLA.F32 q14, q6, d3[1]

        # BLOCK 4
        VMLA.F32 q9, q7, d0[1]
        VLDR    d8, [r9, 64]            // B0
        VMLA.F32 q11, q7, d1[1]
        VLDR    d13, [r9, 104]          // B1
        VMLA.F32 q13, q7, d2[1]
        VLDR    d10, [r9, 80]           // B0

        # BLOCK 5
        VMLA.F32 q15, q7, d3[1]
        VLDR    d15, [r9, 120]          // B1

        # Second group of 16 FMA, First group of loads
        # BLOCK 0
        VMLA.F32 q8, q4, d4[0]
        VLD1.32 {d0}, [r3]!             // A0
        VMLA.F32 q10, q4, d5[0]
        VLD1.32 {d1}, [r12]!            // A1
        VMLA.F32 q12, q4, d6[0]

        # BLOCK 1
        VMLA.F32 q14, q4, d7[0]
        VLDR    d12, [r9, 96]           // B1
        VMLA.F32 q9, q5, d4[0]
        VLDR    d9, [r9, 136]           // B0
        VMLA.F32 q11, q5, d5[0]

        # BLOCK 2
        VMLA.F32 q13, q5, d6[0]
        VLD1.32 {d2}, [r10]!            // A2
        VMLA.F32 q15, q5, d7[0]
        VLD1.32 {d3}, [r7]!             // A3
        VMLA.F32 q8, q6, d4[1]
        SUBS    r5, r5, 16

        # BLOCK 3
        VMLA.F32 q10, q6, d5[1]
        VLDR    d14, [r9, 112]          // B1
        VMLA.F32 q12, q6, d6[1]
        VLDR    d11, [r9, 152]          // B0
        VMLA.F32 q14, q6, d7[1]

        # BLOCK 4
        VMLA.F32 q9, q7, d4[1]
        VLDR    d8, [r9, 128]           // B0
        VMLA.F32 q11, q7, d5[1]
        VLDR    d13, [r9, 168]          // B1
        VMLA.F32 q13, q7, d6[1]
        VLDR    d10, [r9, 144]          // B0

        # BLOCK 5
        VMLA.F32 q15, q7, d7[1]
        VLDR    d15, [r9, 184]          // B1
        ADD     r9, r9, 128             // B++
        BHS     2b

        # Epilogue - 4 floats of A (16 bytes)
3:
        # First group of 16 FMA, Second group loads
        # BLOCK 0
        VMLA.F32 q8, q4, d0[0]
        VLD1.32 {d4}, [r3]!             // A0
        VMLA.F32 q10, q4, d1[0]
        VLD1.32 {d5}, [r12]!            // A1
        VMLA.F32 q12, q4, d2[0]

        # BLOCK 1
        VMLA.F32 q14, q4, d3[0]
        VLDR    d12, [r9, 32]           // B1
        VMLA.F32 q9, q5, d0[0]
        VLDR    d9, [r9, 72]            // B0
        VMLA.F32 q11, q5, d1[0]

        # BLOCK 2
        VMLA.F32 q13, q5, d2[0]
        VLD1.32 {d6}, [r10]!            // A2
        VMLA.F32 q15, q5, d3[0]
        VLD1.32 {d7}, [r7]!             // A3
        VMLA.F32 q8, q6, d0[1]

        # BLOCK 3
        VMLA.F32 q10, q6, d1[1]
        VLDR    d14, [r9, 48]           // B1
        VMLA.F32 q12, q6, d2[1]
        VLDR    d11, [r9, 88]           // B0
        VMLA.F32 q14, q6, d3[1]

        # BLOCK 4
        VMLA.F32 q9, q7, d0[1]
        VLDR    d8, [r9, 64]            // B0
        VMLA.F32 q11, q7, d1[1]
        VLDR    d13, [r9, 104]          // B1
        VMLA.F32 q13, q7, d2[1]
        VLDR    d10, [r9, 80]           // B0

        # BLOCK 5
        VMLA.F32 q15, q7, d3[1]
        VLDR    d15, [r9, 120]          // B1

        # Second group of 16 FMA, First group of loads
        # BLOCK 0
        VMLA.F32 q8, q4, d4[0]
        VLDR    d12, [r9, 96]           // B1
        VMLA.F32 q10, q4, d5[0]
        VMLA.F32 q12, q4, d6[0]

        # BLOCK 1
        VMLA.F32 q14, q4, d7[0]
        VLDR    d14, [r9, 112]          // B1
        VMLA.F32 q9, q5, d4[0]
        VMLA.F32 q11, q5, d5[0]

        # BLOCK 2
        VMLA.F32 q13, q5, d6[0]
        VMLA.F32 q15, q5, d7[0]
        VMLA.F32 q8, q6, d4[1]
        ADD     r9, r9, 128             // B++

        # BLOCK 3
        VMLA.F32 q10, q6, d5[1]
        VMLA.F32 q12, q6, d6[1]
        VMLA.F32 q14, q6, d7[1]
        TST     r5, 15

        # BLOCK 4
        VMLA.F32 q9, q7, d4[1]
        VMLA.F32 q11, q7, d5[1]
        VMLA.F32 q13, q7, d6[1]

        # BLOCK 5
        VMLA.F32 q15, q7, d7[1]

        # Is there a remainder?- 1 to 3 floats of A (4, 8 or 12 bytes)
        BNE     5f

        .p2align 3
4:
        LDR     r5, [sp, 104]           // a
        SUBS    r14, r14, 16            // ks -= MR * sizeof(void*)

        # ks loop
        BHI     1b

        # Load params pointer
        LDR     r0, [sp, 132]           // params
        LDR     r14, [sp, 64]           // p = ks
        # Load min/max values
        VLD1.32 {d4[],d5[]}, [r0]!
        VLD1.32 {d6[],d7[]}, [r0]
        SUBS    r1, r1, 8
        LDR     r0, [sp, 120]           // cn_stride

        # Clamp
        VMAX.F32 q8,  q8, q2
        VMAX.F32 q9,  q9, q2
        VMAX.F32 q10, q10, q2
        VMAX.F32 q11, q11, q2
        VMAX.F32 q12, q12, q2
        VMAX.F32 q13, q13, q2
        VMAX.F32 q14, q14, q2
        VMAX.F32 q15, q15, q2
        VMIN.F32 q8,  q8, q3
        VMIN.F32 q9,  q9, q3
        VMIN.F32 q10, q10, q3
        VMIN.F32 q11, q11, q3
        VMIN.F32 q12, q12, q3
        VMIN.F32 q13, q13, q3
        VMIN.F32 q14, q14, q3
        VMIN.F32 q15, q15, q3

        # Store full 4 x 8
        BLO     7f
        VST1.32 {d28-d31},  [r6], r0
        VST1.32 {d24-d27},  [r8], r0
        VST1.32 {d20-d23},  [r4], r0
        VST1.32 {d16-d19}, [r11], r0

        SUB     r5, r5, r14             // a -= ks

        BHI     0b

        VPOP    {d8-d15}
        ADD     sp, sp, 4               // skip r3
        POP     {r4, r5, r6, r7, r8, r9, r10, r11, pc}

        .p2align 3
5:
        # Is there a remainder?- 2 floats of A (8 bytes)
        TST     r5, 8
        BEQ     6f

        # Remainder - 2 floats of A (8 bytes)
        VLD1.32 {d0}, [r3]!             // A0
        VLDM    r9!, {d8-d11}           // B0
        VLD1.32 {d1}, [r12]!            // A1
        VLD1.32 {d2}, [r10]!            // A2
        VLD1.32 {d3}, [ r7]!            // A3

        VMLA.F32 q8, q4, d0[0]
        VMLA.F32 q9, q5, d0[0]
        VMLA.F32 q10, q4, d1[0]
        VMLA.F32 q11, q5, d1[0]
        VLDM    r9!, {d12-d15}          // B1
        VMLA.F32 q12, q4, d2[0]
        VMLA.F32 q13, q5, d2[0]
        VMLA.F32 q14, q4, d3[0]
        VMLA.F32 q15, q5, d3[0]
        VMLA.F32 q8, q6, d0[1]
        VMLA.F32 q9, q7, d0[1]
        VMLA.F32 q10, q6, d1[1]
        VMLA.F32 q11, q7, d1[1]
        VMLA.F32 q12, q6, d2[1]
        VMLA.F32 q13, q7, d2[1]
        VMLA.F32 q14, q6, d3[1]
        VMLA.F32 q15, q7, d3[1]

        # Is there a remainder?- 1 floats of A (4 bytes)
        TST     r5, 4
        BEQ     4b

6:
        # Remainder- 1 floats of A (4 bytes)
        VLDM    r3!, {s0}               // A0
        VLDM    r9!, {d8-d11}           // B0
        VLDM    r12!, {s2}              // A1
        VLDM    r10!, {s4}              // A2
        VLDM    r7!, {s6}               // A3
        VMLA.F32 q8, q4, d0[0]
        VMLA.F32 q9, q5, d0[0]
        VMLA.F32 q10, q4, d1[0]
        VMLA.F32 q11, q5, d1[0]
        VMLA.F32 q12, q4, d2[0]
        VMLA.F32 q13, q5, d2[0]
        VMLA.F32 q14, q4, d3[0]
        VMLA.F32 q15, q5, d3[0]
        B       4b

        # Store odd width
7:
        TST     r1, 4
        BEQ     8f
        VST1.32 {d28-d29},  [r6]!
        VST1.32 {d24-d25},  [r8]!
        VMOV    q14, q15
        VMOV    q12, q13
        VST1.32 {d20-d21},  [r4]!
        VST1.32 {d16-d17}, [r11]!
        VMOV    q10, q11
        VMOV    q8,  q9

8:
        TST     r1, 2
        BEQ     9f
        VST1.32 {d28},  [r6]!
        VST1.32 {d24},  [r8]!
        VMOV    d28, d29
        VMOV    d24, d25
        VST1.32 {d20},  [r4]!
        VST1.32 {d16}, [r11]!
        VMOV    d20, d21
        VMOV    d16, d17

9:
        TST     r1, 1
        BEQ     10f
        VST1.32 {d28[0]},  [r6]!
        VST1.32 {d24[0]},  [r8]!
        VST1.32 {d20[0]},  [r4]!
        VST1.32 {d16[0]}, [r11]!

10:
        VPOP    {d8-d15}
        ADD     sp, sp, 4               // skip r3
        POP     {r4, r5, r6, r7, r8, r9, r10, r11, pc}

END_FUNCTION xnn_f32_igemm_minmax_ukernel_4x8__aarch32_neon_cortex_a55

#ifdef __ELF__
.section ".note.GNU-stack","",%progbits
#endif
